{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc76523c",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp a2wav"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2939fb1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| exporti\n",
    "from vocos import Vocos\n",
    "import torch\n",
    "import torchaudio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed1dde62",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class Vocoder:\n",
    "    def __init__(self, repo_id=\"charactr/vocos-encodec-24khz\"):\n",
    "        self.vocos = Vocos.from_pretrained(repo_id).cuda()\n",
    "    \n",
    "    def is_notebook(self):\n",
    "        try:\n",
    "            return get_ipython().__class__.__name__ == \"ZMQInteractiveShell\"\n",
    "        except:\n",
    "            return False\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def decode(self, atoks):\n",
    "        if len(atoks.shape) == 3:\n",
    "            b,q,t = atoks.shape\n",
    "            atoks = atoks.permute(1,0,2)\n",
    "        else:\n",
    "            q,t = atoks.shape\n",
    "        \n",
    "        features = self.vocos.codes_to_features(atoks)\n",
    "        bandwidth_id = torch.tensor({2:0,4:1,8:2}[q]).cuda()\n",
    "        return self.vocos.decode(features, bandwidth_id=bandwidth_id)\n",
    "        \n",
    "    def decode_to_file(self, fname, atoks):\n",
    "        audio = self.decode(atoks)\n",
    "        torchaudio.save(fname, audio.cpu(), 24000)\n",
    "        if self.is_notebook():\n",
    "            from IPython.display import display, HTML, Audio\n",
    "            display(HTML(f'<a href=\"{fname}\" target=\"_blank\">Listen to {fname}</a>'))\n",
    "        \n",
    "    def decode_to_notebook(self, atoks):\n",
    "        from IPython.display import display, HTML, Audio\n",
    "\n",
    "        audio = self.decode(atoks)\n",
    "        display(Audio(audio.cpu().numpy(), rate=24000))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
